import sys
sys.path.append('..')

import random
import os

import wandb
import torch
import torch.cuda

import datasets.dataset_parser as aug

from knowledge_tracing.args import ARGS
from knowledge_tracing.network.DKT import DKT
from knowledge_tracing.network.DKVMN import DKVMN
from knowledge_tracing.network.transformer import SAINT, SAKT
from knowledge_tracing.trainer import Trainer
from knowledge_tracing.network.features import CategoricalFeature, CategoricalMultiFeature, PositionalFeature
from knowledge_tracing.network.features import IsCorrect, LossMask, SequenceSize
from knowledge_tracing.dataset import InteractionDataset


def setup():
    """Prepare the process-wide environment; call once at process start.

    In order, this:
      1. creates ``ARGS.weight_path`` (no error if it already exists);
      2. when ``ARGS.device == "cuda"``, exports ``CUDA_VISIBLE_DEVICES`` from
         ``ARGS.gpu`` and selects the first visible device;
      3. seeds Python's ``random`` and torch (CPU and all CUDA devices) with
         ``ARGS.random_seed``;
      4. starts a wandb run when ``ARGS.use_wandb`` is set.

    Raises:
        RuntimeError: if CUDA was requested but is not available.
    """
    # Directory for weight files.
    os.makedirs(ARGS.weight_path, exist_ok=True)

    if ARGS.device == "cuda":
        # MUST set CUDA_VISIBLE_DEVICES before any torch.cuda calls,
        # including `is_available()`.
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, ARGS.gpu))
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        # Index 0 is the first entry of CUDA_VISIBLE_DEVICES.
        torch.cuda.set_device(0)

    random.seed(ARGS.random_seed)
    torch.manual_seed(ARGS.random_seed)
    torch.cuda.manual_seed_all(ARGS.random_seed)

    if ARGS.use_wandb:
        wandb.init(project=ARGS.project, name=ARGS.name, tags=ARGS.wandb_tags, config=ARGS)


def make_data_parser():
    """Build the interaction parser for the configured dataset.

    Per-user data is read from ``os.path.join(ARGS.data_root, 'users/')``.
    Only the 'EdNet-KT1' dataset forwards ``ARGS.item_info_root`` and
    ``ARGS.sid_mapper_root`` to the parser; every other dataset passes
    ``None`` for both.

    Returns:
        The ``aug.Parser`` instance.
    """
    name = ARGS.dataset_name
    is_ednet = name == 'EdNet-KT1'
    return aug.Parser(
        data_root=os.path.join(ARGS.data_root, 'users/'),
        dataset_name=name,
        item_info_root=ARGS.item_info_root if is_ednet else None,
        sid_mapper_root=ARGS.sid_mapper_root if is_ednet else None,
    )


def _pair_features(features, names, dims, side):
    """Return ``[(features[name], int(dim)), ...]`` for the requested names.

    Args:
        features: mapping from feature name to feature object.
        names: requested feature names, in order.
        dims: one dimension per entry of ``names`` (converted with ``int``).
        side: label such as 'encoder' / 'decoder', used only in messages.

    Raises:
        ValueError: if ``names`` and ``dims`` differ in length, or a name is
            not a key of ``features``.
    """
    names, dims = list(names), list(dims)
    # zip() would silently drop the tail of the longer list, hiding a
    # missing or extra dimension in the configuration.
    if len(names) != len(dims):
        raise ValueError(
            f'{side} feature names ({len(names)}) and dims ({len(dims)}) '
            f'must have the same length: {names} vs {dims}')
    unknown = [name for name in names if name not in features]
    if unknown:
        raise ValueError(
            f'unknown {side} feature(s) {unknown}; '
            f'available: {sorted(features)}')
    return [(features[name], int(dim)) for name, dim in zip(names, dims)]


def make_features(dataset_name, data_root):
    """Build encoder/decoder feature lists from ``ARGS`` feature settings.

    Args:
        dataset_name: forwarded to ``aug.Constants`` (sizes of item/tag vocab).
        data_root: forwarded to ``aug.Constants``.

    Returns:
        ``(enc_features, dec_features, all_features)``. The first two are
        lists of ``(feature, dim)`` pairs selected by
        ``ARGS.enc_feature_names``/``ARGS.enc_feature_dims`` and
        ``ARGS.dec_feature_names``/``ARGS.dec_feature_dims``.
        ``all_features`` is the encoder list, then the decoder list, then
        ``loss_mask`` and ``sequence_size`` with dimension 0.

    Raises:
        ValueError: if a names list and its dims list differ in length, or a
            requested feature name is not defined below.
    """
    dataset_constant = aug.Constants(dataset_name, data_root)
    features = {
        'item_idx': CategoricalFeature(dataset_constant.NUM_ITEMS, name='item_idx', padding=0, start_token=None),
        # NOTE(review): twice the item vocab — presumably item x correctness; confirm.
        'interaction_idx': CategoricalFeature(dataset_constant.NUM_ITEMS * 2, name='interaction_idx', padding=0, start_token=0),
        'is_correct': IsCorrect(name='is_correct'),
        'tags': CategoricalMultiFeature(dataset_constant.NUM_TAGS, collate=ARGS.collate_fn,
                                        max_num_features=dataset_constant.MAX_NUM_TAGS_PER_ITEM,
                                        name='tags', padding=0, start_token=None),
        'loss_mask':  LossMask(name='loss_mask'),
        'sequence_size': SequenceSize(name='sequence_size'),
        'position': PositionalFeature(name='position', seq_len=ARGS.seq_size)
    }

    enc_features = _pair_features(features, ARGS.enc_feature_names, ARGS.enc_feature_dims, 'encoder')
    dec_features = _pair_features(features, ARGS.dec_feature_names, ARGS.dec_feature_dims, 'decoder')

    all_features = enc_features + dec_features + [(features['loss_mask'], 0), (features['sequence_size'], 0)]
    return enc_features, dec_features, all_features


def make_dataset(data_parser, features, i=1, train_small_rate=1.0):
    """Build the train / validation / test datasets for split ``i``.

    User lists are read from ``ARGS.data_root``:
      * train: ``train_users_{i}.csv``, or
        ``train_users_{i}_{train_small_rate}.csv`` when
        ``train_small_rate != 1.0``;
      * val: ``val_users_{i}.csv``;
      * test: ``test_users.csv`` (shared by all splits).

    Only the training dataset is built with ``is_training=True`` and
    ``ARGS.augmentations``. Each split is subsampled by its
    ``ARGS.*_data_frac``. The size of every split is printed.

    Returns:
        ``(train_data, val_data, test_data)``.
    """
    root = ARGS.data_root
    if train_small_rate == 1.0:
        train_file = f'train_users_{i}.csv'
    else:
        train_file = f'train_users_{i}_{train_small_rate}.csv'

    train_data = InteractionDataset(
        data_parser,
        users=os.path.join(root, train_file),
        features=features,
        is_training=True,
        aug_methods=ARGS.augmentations,
        fraction=ARGS.train_data_frac,
    )
    val_data = InteractionDataset(
        data_parser,
        users=os.path.join(root, f'val_users_{i}.csv'),
        features=features,
        is_training=False,
        fraction=ARGS.val_data_frac,
    )
    test_data = InteractionDataset(
        data_parser,
        users=os.path.join(root, 'test_users.csv'),
        features=features,
        is_training=False,
        fraction=ARGS.test_data_frac,
    )

    # Left-align the split labels to width 5 so the sizes line up.
    for label, split in (('Train', train_data), ('Val', val_data), ('Test', test_data)):
        print(f'{label:<5} data size: {len(split)}')

    return train_data, val_data, test_data


def make_model(encoder_features, decoder_features):
    """Instantiate the model selected by ``ARGS.model_type``.

    Supported types: 'SAINT', 'SAKT', 'DKVMN', 'DKT' and 'qDKT'. 'qDKT' is
    the DKT network built with ``use_laplacian=True``. Only SAINT consumes
    ``decoder_features``; the other models ignore it. When
    ``ARGS.use_wandb`` is set, the model is registered with ``wandb.watch``.

    Args:
        encoder_features: list of ``(feature, dim)`` pairs for the encoder.
        decoder_features: list of ``(feature, dim)`` pairs for the decoder.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``ARGS.model_type`` is not one of the supported types.
    """
    model_type = ARGS.model_type
    if model_type == 'SAINT':
        model = SAINT(device=ARGS.device,
                      encoder_features=encoder_features,
                      decoder_features=decoder_features,
                      N=ARGS.layer_count,
                      d_model=ARGS.d_model_count,
                      h=ARGS.head_count,
                      d_ff=ARGS.d_model_count * 4,
                      dropout=ARGS.dropout_rate)
    elif model_type == 'SAKT':
        model = SAKT(device=ARGS.device,
                     encoder_features=encoder_features,
                     d_model=ARGS.d_model_count,
                     h=ARGS.head_count,
                     dropout=ARGS.dropout_rate)
    elif model_type == 'DKVMN':
        model = DKVMN(device=ARGS.device,
                      encoder_features=encoder_features,
                      summary_dim=ARGS.d_model_count,
                      concept_num=ARGS.concept_num)
    elif model_type in ('DKT', 'qDKT'):
        model = DKT(device=ARGS.device,
                    encoder_features=encoder_features,
                    hidden_dim=ARGS.d_model_count,
                    dropout=ARGS.dropout_rate,
                    use_laplacian=(model_type == 'qDKT'))
    else:
        # Without this branch `model` would be unbound below and the failure
        # would surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of "
            f"'SAINT', 'SAKT', 'DKVMN', 'DKT', 'qDKT'")

    if ARGS.use_wandb:
        wandb.watch(model)

    return model


def run(model, train_data, val_data, test_data, i):
    """Train ``model`` on split ``i``, then evaluate it on the test set.

    Hyper-parameters (device, warm-up steps, model width, epochs/steps,
    learning rate, weight path, wandb flag) come from ``ARGS``.

    Returns:
        ``(test_acc, test_auc, max_acc, max_auc)`` as read from the
        ``Trainer`` after ``train()`` and ``test()``.
        NOTE(review): ``max_*`` are presumably the best validation scores;
        confirm in Trainer.
    """
    trainer = Trainer(
        model, ARGS.device,
        ARGS.warm_up_step_count, ARGS.d_model_count,
        ARGS.use_wandb, ARGS.num_epochs, ARGS.num_steps,
        ARGS.weight_path, ARGS.lr,
        train_data, val_data, test_data, i,
    )
    trainer.train()
    trainer.test()

    metrics = (trainer.test_acc, trainer.test_auc, trainer.max_acc, trainer.max_auc)
    return metrics


if __name__ == '__main__':
    # Script entry point: set up the environment, build the data and the
    # model from ARGS, then train and evaluate on split ARGS.split_num.
    setup()
    data_parser = make_data_parser()
    print(f'Number of users: {data_parser.num_users()}')
    some_users = data_parser.all_users()[:10]
    # len() rather than truthiness: all_users() may return an array-like
    # whose truth value is ambiguous.
    if len(some_users) == 0:
        sys.exit(f'No users found for dataset {ARGS.dataset_name!r} '
                 f'under {ARGS.data_root!r}')
    user0 = some_users[0]
    print(f'User 0: {user0}')
    print(f'User 0 Number of interactions: {data_parser.num_interactions(user0)}')

    enc_features, dec_features, all_features = make_features(ARGS.dataset_name, ARGS.data_root)

    train_data, val_data, test_data = make_dataset(data_parser, all_features, ARGS.split_num, ARGS.train_small_rate)
    model = make_model(enc_features, dec_features)
    test_acc, test_auc, max_val_acc, max_val_auc = run(model, train_data, val_data, test_data, ARGS.split_num)
    print(f'Test acc: {test_acc}, Test auc: {test_auc}, '
          f'Max val acc: {max_val_acc}, Max val auc: {max_val_auc}')

